import copy
import os
from typing import Any, Dict, Union

from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

DEFAULT_TEXT_MODEL_NAME = "jhu-clsp/ettin-encoder-150m"
DEFAULT_VISION_MODEL_NAME = "google/siglip2-base-patch16-512"


def collect_arg_in_candidates(config, candidates, default=None) -> Any:
    """Gets the first available argument in a config given a list of candidate names.

    Args:
        config: Either an attribute-style config object (e.g. a ``PretrainedConfig``)
            or a mapping such as a plain ``dict``.
        candidates: Candidate attribute/key names, tried in order.
        default: Value returned when no candidate matches. Note: because ``None``
            is the sentinel for "no default", an explicit ``default=None`` still
            results in a ``ValueError`` when nothing matches.

    Returns:
        The value of the first candidate found on/in ``config``.

    Raises:
        ValueError: If no candidate matches and no (non-``None``) default was given.
    """
    for name in candidates:
        if hasattr(config, name):
            return getattr(config, name)
        # Attribute-style configs (e.g. PretrainedConfig) typically do not
        # implement `__contains__`; guard the mapping-style lookup so a
        # TypeError falls through to the next candidate instead of crashing.
        try:
            if name in config:
                return config[name]
        except TypeError:
            continue
    if default is not None:
        return default
    raise ValueError(f"No matching arguments found in candidates. Candidates: {candidates}, Config: {config}")


class ModernVBertTextConfig(PretrainedConfig):
    r"""Configuration for the text encoder of a ModernVBERT model.

    Stores the arguments needed to instantiate the ModernBERT-style text tower.
    Instantiating with the defaults yields a configuration similar to
    [jhu-clsp/ettin-encoder-150m](https://huggingface.co/jhu-clsp/ettin-encoder-150m).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """

    model_type = "modernvbert_text"

    def __init__(
        self,
        text_model_name=DEFAULT_TEXT_MODEL_NAME,
        hidden_size=768,
        num_hidden_layers=22,
        intermediate_size=1152,
        mlp_bias=False,
        vocab_size=50368,
        **kwargs,
    ):
        # Merge the explicit architecture arguments into kwargs and forward
        # everything to PretrainedConfig, which stores them as attributes so
        # they round-trip through serialization.
        kwargs.update(
            text_model_name=text_model_name,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            intermediate_size=intermediate_size,
            mlp_bias=mlp_bias,
            vocab_size=vocab_size,
        )
        super().__init__(**kwargs)

    @classmethod
    def from_base_model(cls, text_model_name=DEFAULT_TEXT_MODEL_NAME, **kwargs):
        """Build a text config by reading the relevant fields off a pretrained base model."""
        base = AutoConfig.from_pretrained(text_model_name, trust_remote_code=True)
        # Composite (multimodal) configs nest the text settings under `text_config`.
        if hasattr(base, "text_config"):
            base = base.text_config

        return cls(
            text_model_name=text_model_name,
            hidden_size=collect_arg_in_candidates(base, ["hidden_size", "embed_dim"]),
            num_hidden_layers=collect_arg_in_candidates(base, ["num_hidden_layers", "num_hidden_blocks"]),
            intermediate_size=collect_arg_in_candidates(base, ["intermediate_size", "mlp_dim"]),
            mlp_bias=collect_arg_in_candidates(base, ["mlp_bias", "mlp_hidden_bias"], default=False),
            vocab_size=collect_arg_in_candidates(base, ["vocab_size"]),
            **kwargs,
        )


class ModernVBertVisionConfig(PretrainedConfig):
    r"""Configuration for the vision encoder (SigLIP-style) of a ModernVBERT model.

    Stores the arguments needed to instantiate the vision tower. Instantiating
    with the defaults yields a configuration similar to that of SigLIP.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    """

    model_type = "modernvbert_vision"

    # Let `config.hidden_size` transparently read/write the `embed_dim` attribute.
    attribute_map = {
        "hidden_size": "embed_dim",
    }

    def __init__(
        self,
        vision_model_name=DEFAULT_VISION_MODEL_NAME,
        embed_dim=768,
        image_size=512,
        patch_size=16,
        num_hidden_layers=12,
        intermediate_size=3072,
        **kwargs,
    ):
        # Merge the explicit architecture arguments into kwargs and forward
        # everything to PretrainedConfig, which stores them as attributes.
        kwargs.update(
            vision_model_name=vision_model_name,
            embed_dim=embed_dim,
            image_size=image_size,
            patch_size=patch_size,
            num_hidden_layers=num_hidden_layers,
            intermediate_size=intermediate_size,
        )
        super().__init__(**kwargs)

    @classmethod
    def from_base_model(cls, vision_model_name=DEFAULT_VISION_MODEL_NAME, **kwargs):
        """Build a vision config by reading the relevant fields off a pretrained base model."""
        base = AutoConfig.from_pretrained(vision_model_name, trust_remote_code=True)
        # Composite (multimodal) configs nest the vision settings under `vision_config`.
        if hasattr(base, "vision_config"):
            base = base.vision_config

        return cls(
            vision_model_name=vision_model_name,
            embed_dim=collect_arg_in_candidates(base, ["embed_dim", "hidden_size"]),
            image_size=collect_arg_in_candidates(base, ["image_size", "img_size"]),
            patch_size=collect_arg_in_candidates(base, ["patch_size"]),
            num_hidden_layers=collect_arg_in_candidates(base, ["num_hidden_layers", "num_hidden_blocks"]),
            intermediate_size=collect_arg_in_candidates(base, ["intermediate_size", "mlp_dim"]),
            **kwargs,
        )


class ModernVBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a `ModernVBert` model. It is used to
    instantiate a ModernVBert model according to the specified arguments and defines the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
    See the documentation for [`PretrainedConfig`] for more details.

    Args:
        text_config (`PretrainedConfig` or `dict`, optional):
            Custom text config or a dict with a `text_model_name` key for the text encoder. If `None`, a
            config is derived from the default text backbone `DEFAULT_TEXT_MODEL_NAME`.
        vision_config (`PretrainedConfig` or `dict`, optional):
            Custom vision config or a dict with a `vision_model_name` key for the vision encoder. If `None`, a
            config is derived from the default vision backbone `DEFAULT_VISION_MODEL_NAME`.
        image_token_id (`int`, optional, defaults to 50407):
            Token id reserved for image tokens inserted into the text stream.
        vocab_size (`int`, optional, defaults to 50368):
            Vocabulary size used by the text embeddings.
        use_cache (`bool`, optional, defaults to `True`):
            Whether to cache key/value tensors for attention (relevant for decoder architectures).
        tie_word_embeddings (`bool`, optional, defaults to `False`):
            Whether to tie input token embeddings and output token embeddings.
        pixel_shuffle_factor (`int`, optional, defaults to 4):
            Scale factor used by any pixel-shuffle / upsampling operations in the vision head.
        additional_vocab_size (`int`, optional, defaults to 0):
            Number of extra tokens appended to the base vocabulary (useful for adapters / special tokens).
        pad_token_id (`int`, optional):
            Padding token id.
        initializer_range (`float`, optional, defaults to 0.02):
            Stddev used for weight initialization.
        freeze_config (`Any`, optional):
            Optional config describing which submodules to freeze during training.
        use_resampler (`bool`, optional, defaults to `False`):
            Whether to enable an additional resampler on visual features.
        neftune_noise_alpha (`float`, optional, defaults to 0.0):
            Alpha parameter for neftune noise injection.

    Example:
    ```python
    >>> from modernvbert import ModernVBertConfig
    >>> # Initializing configuration
    >>> configuration = ModernVBertConfig()
    >>> # Initializing a model from the configuration (model class is implemented in
    >>> # `modernvbert.modeling_modernvbert`)
    >>> # from modernvbert import ModernVBertModel
    >>> # model = ModernVBertModel(configuration)
    >>> # Accessing the model configuration
    >>> # cfg = model.config
    ```"""

    model_type = "modernvbert"
    is_composition = True

    def __init__(
        self,
        text_config: Union[PretrainedConfig, Dict[str, Any]] = None,
        vision_config: Union[PretrainedConfig, Dict[str, Any]] = None,
        image_token_id: int = 50407,
        vocab_size=50368,
        use_cache=True,
        tie_word_embeddings=False,
        freeze_config=None,
        pad_token_id=None,
        initializer_range=0.02,
        pixel_shuffle_factor=4,
        use_resampler=False,
        additional_vocab_size=0,
        neftune_noise_alpha=0.0,
        **kwargs,
    ):
        self.image_token_id = image_token_id
        self.use_cache = use_cache
        self.tie_word_embeddings = tie_word_embeddings
        # Kept as an alias of `pixel_shuffle_factor` for callers that read `scale_factor`.
        self.scale_factor = pixel_shuffle_factor
        self.additional_vocab_size = additional_vocab_size

        # BUGFIX: the default path previously passed the loaded AutoConfig object
        # *positionally* into ModernVBertTextConfig(...), where it landed in the
        # `text_model_name` parameter and every architecture field silently kept
        # its hard-coded default. Derive the sub-config from the backbone instead.
        if text_config is None:
            text_config = ModernVBertTextConfig.from_base_model(DEFAULT_TEXT_MODEL_NAME)
        elif isinstance(text_config, dict):
            text_config = ModernVBertTextConfig.from_dict(text_config)
        self.text_config = text_config

        # Same fix as for the text config above.
        if vision_config is None:
            vision_config = ModernVBertVisionConfig.from_base_model(DEFAULT_VISION_MODEL_NAME)
        elif isinstance(vision_config, dict):
            vision_config = ModernVBertVisionConfig.from_dict(vision_config)
        self.vision_config = vision_config

        self.freeze_config = freeze_config
        self.pixel_shuffle_factor = pixel_shuffle_factor
        self.use_resampler = use_resampler
        self.neftune_noise_alpha = neftune_noise_alpha
        self.initializer_range = initializer_range

        # The composite model's hidden size defaults to the text tower's, but an
        # explicit `hidden_size` kwarg takes precedence.
        hidden_size = kwargs.pop("hidden_size", self.text_config.hidden_size)

        super().__init__(
            **kwargs,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
        )

    def to_dict(self):
        """Serialize this config to a plain dict, expanding the nested sub-configs."""
        output = copy.deepcopy(self.__dict__)
        output["model_type"] = self.__class__.model_type
        # Replace the nested config objects with their dict representations so
        # the result is fully JSON-serializable.
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        return output

    @classmethod
    def from_pretrained_models(
        cls,
        text_model_name: Union[str, os.PathLike],
        vision_model_name: Union[str, os.PathLike],
        **kwargs,
    ) -> "PretrainedConfig":
        """Build a composite config from a pretrained text backbone and vision backbone.

        Args:
            text_model_name: Hub id or local path of the text backbone.
            vision_model_name: Hub id or local path of the vision backbone.
            **kwargs: Forwarded to the `ModernVBertConfig` constructor.
        """
        text_model_config = ModernVBertTextConfig.from_base_model(text_model_name)
        vision_model_config = ModernVBertVisionConfig.from_base_model(vision_model_name)
        return cls(
            text_config=text_model_config,
            vision_config=vision_model_config,
            **kwargs,
        )
